import numpy as np
import pandas as pd
import sys
import anndata
import scanpy as sc
import scipy.stats
import scikit_posthocs as sp


def color_sig_red(value):
    """Map a p-value cell to a CSS background rule for a pandas ``Styler``.

    ``value`` may be a number or a numeric string (for example an
    already-formatted p-value); it is converted with ``float``. Values strictly
    below 0.05 get a ``mistyrose`` background. Everything else, including NaN
    (which never compares below 0.05), gets a white background.
    """
    significant = float(value) < .05
    return 'background-color: mistyrose;' if significant else 'background-color: white;'


def kruskal_between_groups(df, value_col='value', group_col='Group', dunn=True):
    """Run a Kruskal-Wallis test of ``value_col`` across the levels of ``group_col``.

    The Kruskal-Wallis p-value is always printed. If it is below 0.05 and
    ``dunn`` is true, Dunn's post-hoc test (Bonferroni-adjusted) is also run
    between every pair of groups.

    Parameters
    ----------
    df : pandas.DataFrame
        Data with one row per observation.
    value_col : str
        Column holding the values to compare.
    group_col : str
        Column holding each row's group label.
    dunn : bool
        Whether to run Dunn's post-hoc test when the omnibus test is significant.

    Returns
    -------
    pandas.io.formats.style.Styler or None
        The pairwise Dunn p-value table (rows and columns are the sorted group
        labels, values formatted to two significant digits and coloured with
        ``color_sig_red``) when the post-hoc test runs; otherwise ``None``.
    """
    groups = np.sort(df[group_col].unique())
    # Build each group's sample once; both tests consume the same lists.
    samples = [df.loc[df[group_col] == g, value_col].tolist() for g in groups]

    k = scipy.stats.kruskal(*samples).pvalue
    print(f"Kruskal-Wallis p-value: {k:.2g}")

    if not (dunn and k < .05):
        return None

    d = pd.DataFrame(sp.posthoc_dunn(samples, p_adjust='bonferroni'))
    # posthoc_dunn labels rows/columns by position; relabel with group names.
    d.columns = d.index = groups
    for col in d.columns:
        d[col] = d[col].map('{:,.2g}'.format)
    print("Dunn post-hoc test")

    styler = d.style
    # pandas >= 2.1 renamed Styler.applymap to Styler.map and deprecated the
    # old name; use whichever this pandas version provides.
    cell_map = getattr(styler, 'map', None) or styler.applymap
    return cell_map(color_sig_red)


def diff_abundance(ad, group_col, reference='rest', count_de=True):
    """Wilcoxon differential-abundance test of each group against ``reference``.

    Calls ``sc.tl.rank_genes_groups`` on ``ad`` (scanpy stores its results on
    the AnnData object, so ``ad`` is modified). It then gathers the per-group
    result tables with ``sc.get.rank_genes_groups_df``.

    Parameters
    ----------
    ad : anndata.AnnData
        Object whose ``obs[group_col]`` defines the groups.
    group_col : str
        ``obs`` column to group by.
    reference : str
        Group to compare against, or ``'rest'`` (the default) to compare each
        group with all other cells. The reference group itself is skipped.
    count_de : bool
        If true, print how many features have ``pvals_adj < 0.05`` per group.

    Returns
    -------
    pandas.DataFrame
        The per-group tables stacked row-wise, with an extra ``group_col``
        column naming the group. Each table keeps its own index, so the index
        can repeat. An empty DataFrame is returned if there are no
        non-reference groups.
    """
    sc.tl.rank_genes_groups(ad, groupby=group_col, reference=reference, method="wilcoxon")
    groups = np.sort(ad.obs[group_col].unique())
    nonreference_groups = [x for x in groups if x != reference]

    # Collect the frames and concatenate once. Repeated concat in a loop is
    # quadratic, and concatenating onto an empty DataFrame is deprecated
    # (FutureWarning, possible dtype changes) in pandas >= 2.1.
    frames = []
    for group in nonreference_groups:
        curr_df = sc.get.rank_genes_groups_df(ad, group)
        curr_df[group_col] = group
        frames.append(curr_df)
        if count_de:
            n_de = np.sum(curr_df['pvals_adj'] < .05)
            print(f'{n_de} features are differentially expressed in {group}')

    if not frames:
        return pd.DataFrame()
    return pd.concat(frames)


def compare_motif_positions(sites, motifs, metadata=None):
    """Compare motif-site positions across sequence groups, one motif at a time.

    For each motif in ``motifs``, ``kruskal_between_groups`` is run on the
    ``pos`` column of that motif's sites, grouped by the ``Group`` column
    joined in from ``metadata``.

    Parameters
    ----------
    sites : pandas.DataFrame
        Motif hits indexed by sequence ID, with at least ``Matrix_id`` and
        ``pos`` columns. The caller's DataFrame is not modified.
    motifs : list-like
        Matrix IDs to test.
    metadata : pandas.DataFrame, optional
        Per-sequence table with ``label``, ``Group`` and ``SeqID`` columns.
        ``SeqID`` is matched against ``sites``' index after it is cast to str.
        If omitted, a global DataFrame named ``all`` in this function's module
        namespace is used, which is what earlier versions implicitly required.

    Returns
    -------
    dict
        Maps each motif to the value returned by ``kruskal_between_groups``
        for it: a styled Dunn table, or ``None``.

    Raises
    ------
    TypeError
        If ``metadata`` is omitted and no global DataFrame named ``all`` exists.
    """
    if metadata is None:
        # The original code subscripted ``all``, relying on a global DataFrame
        # shadowing the builtin ``all``. Without one, it hit the builtin
        # function and crashed. Keep that global as a fallback for old callers.
        metadata = globals().get('all')
        if not isinstance(metadata, pd.DataFrame):
            raise TypeError(
                "compare_motif_positions() needs a `metadata` DataFrame with "
                "'label', 'Group' and 'SeqID' columns")

    # Work on a copy so the caller's index and 'pos' dtype are left untouched.
    sites = sites.copy()
    sites.index = sites.index.astype(str)
    sites['pos'] = sites['pos'].astype(int)
    sites = sites[sites['Matrix_id'].isin(motifs)]
    sites = sites.merge(metadata[['label', 'Group', 'SeqID']], left_index=True, right_on='SeqID')
    sites = sites[['Matrix_id', 'pos', 'Group']]

    results = {}
    for m in motifs:
        print(m)
        # Keep the post-hoc table instead of discarding it inside the loop.
        results[m] = kruskal_between_groups(sites[sites['Matrix_id'] == m], value_col='pos', group_col='Group')
    return results


def proportion_between_groups(df, value_col, ref_value=None, group_col='Group', ref_group='Test Set'):
    """Chi-squared test of each group's ``value_col`` frequencies against a reference group.

    Each group other than ``ref_group`` is compared with ``ref_group`` using a
    chi-squared goodness-of-fit test. The expected counts are the reference
    group's category proportions scaled to the tested group's size.

    Parameters
    ----------
    df : pandas.DataFrame
        Data with one row per observation.
    value_col : str
        Categorical column whose frequencies are compared.
    ref_value : optional
        If ``None``, compare the full distribution of ``value_col`` values
        (NaN values are not counted). Otherwise compare two categories: rows
        equal to ``ref_value`` versus all other rows.
    group_col : str
        Column holding each row's group label.
    ref_group : str
        Group used as the expected distribution.

    Returns
    -------
    dict
        Maps each non-reference group (in sorted order) to its p-value. Each
        p-value is also printed.

    Notes
    -----
    A category with zero count in ``ref_group`` has expected count 0. It
    yields an infinite statistic if the tested group has any such rows.
    """
    groups = np.sort(df[group_col].unique())
    nonreference_groups = [x for x in groups if x != ref_group]

    if ref_value is None:
        # Full category distribution. fill_value=0 makes categories absent
        # from a group count as 0 instead of NaN, which would make the
        # p-value NaN.
        counts = df.groupby(group_col)[value_col].value_counts().unstack(fill_value=0)
        label = value_col
    else:
        # Two categories: rows equal to ref_value vs. everything else.
        is_ref = (df[value_col] == ref_value).astype(int)
        n_match = is_ref.groupby(df[group_col]).sum()
        n_total = is_ref.groupby(df[group_col]).size()
        counts = pd.DataFrame({'match': n_match, 'other': n_total - n_match})
        label = f"{value_col}={ref_value}"
    # Every row of `counts` shares the same column order, so the arrays align.
    freqs = {g: counts.loc[g].to_numpy(dtype=float) for g in counts.index}

    ref_prop = freqs[ref_group] / freqs[ref_group].sum()
    pvals = {}
    for group in nonreference_groups:
        observed = freqs[group]
        pval = scipy.stats.chisquare(observed, f_exp=ref_prop * observed.sum()).pvalue
        print(f"Chi-squared p-value for {label} in group {group} vs. {ref_group}: {pval:.2g}")
        pvals[group] = pval
    return pvals